"""
PM 服务实现

date: 2023/8/20
author: SiHeng Tang
file: service.py
copyright(c) DFSA Software Developers
此程序不能提供任何担保 WITHOUT WARRANTY OF ANY KIND
"""
import os
import os.path
import sqlite3
import threading

import requests

from core import PApplication, PService
from libpm import core


class SvrTools:
    """Stateless filesystem helpers used by the services in this module."""

    @staticmethod
    def verify_dir(path: str) -> str:
        """Make sure ``path`` exists as a directory and return it unchanged.

        Missing parent directories are created as well.

        :param path: directory to create if it does not exist yet
        :return: the same ``path`` that was passed in
        :raises FileExistsError: if ``path`` exists but is not a directory
        """
        # exist_ok avoids the check-then-create race when two callers ask for
        # the same directory at the same time.
        os.makedirs(path, exist_ok=True)

        return path

    @staticmethod
    def is_included(top: str, path: str) -> bool:
        """Tell whether ``path`` lies lexically inside directory ``top``.

        Both arguments are normalised first, so trailing separators and
        ``.``/``..`` segments in either one do not affect the answer. The
        check is purely textual: symbolic links are not resolved.

        :param top: directory that must contain ``path``
        :param path: path to test
        :return: True if ``path`` is ``top`` itself or below it, else False
        """
        norm_top = os.path.normpath(top)
        norm_path = os.path.normpath(path)

        try:
            return os.path.commonpath([norm_path, norm_top]) == norm_top
        except ValueError:
            # commonpath refuses to compare absolute with relative paths or
            # paths on different drives; such a path cannot be inside top.
            return False


###############################
# 数据库
###############################

class _Connection(core.PObject):
    """Context manager giving exclusive access to one SQLite database file.

    Entering acquires the shared lock, opens a fresh connection to
    ``db_path`` and returns a cursor on it. Leaving commits when the ``with``
    body finished normally and rolls back when it raised, then closes the
    cursor and the connection and releases the lock. Exceptions raised by the
    body are never suppressed.
    """

    def __init__(self, db_path: str, lock: threading.Lock):
        """
        :param db_path: path of the SQLite database file to open
        :param lock: lock held for the whole lifetime of the ``with`` block
        """
        self.db_path = db_path
        self.connection: sqlite3.Connection = None
        self.cursor: sqlite3.Cursor = None
        self.lock = lock

    def __enter__(self) -> sqlite3.Cursor:
        """Acquire the lock, open the database and return a cursor."""
        self.lock.acquire()
        try:
            self.connection = sqlite3.connect(self.db_path)
            self.cursor = self.connection.cursor()
        except BaseException:
            # __exit__ is not called when __enter__ fails, so the lock must
            # be given back here or every later caller would block forever.
            self._close_and_unlock()
            raise
        return self.cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit or roll back, then always close and release the lock."""
        try:
            if exc_type is None:
                self.connection.commit()
            else:
                # Do not persist the partial work of a failed block.
                self.connection.rollback()
        finally:
            self._close_and_unlock()
        return False

    def _close_and_unlock(self):
        """Close whatever was opened and release the lock unconditionally."""
        try:
            if self.cursor is not None:
                self.cursor.close()
            if self.connection is not None:
                self.connection.close()
        finally:
            self.cursor = None
            self.connection = None
            self.lock.release()


class DBService(core.PService):
    """Database service backed by one SQLite file per extension.

    The file lives at ``<top dir>/db/<extension_name>.db``; the ``db``
    directory is created on construction if it is missing. A single lock
    owned by this service is handed to every connection it creates.
    """

    def __init__(self, app: core.PApplication):
        """
        :param app: application object; its top directory and its
            ``argv["extension_name"]`` entry determine the database path
        """
        super().__init__(app)

        db_root = self._app.get_top_dir() + '/db'
        self.db_dir = SvrTools.verify_dir(db_root)

        extension = self._app.argv["extension_name"]
        self.db_path = f'{self.db_dir}/{extension}.db'

        # Shared by all _Connection objects created through connection().
        self.db_lock = threading.Lock()

    def connection(self):
        """
        Return a context manager for use in a ``with`` statement; entering
        it yields a database cursor.

        example:
            with self.bus.db.connection() as cur:
                cur.execute(sql_str)
        """
        return _Connection(db_path=self.db_path, lock=self.db_lock)


###############################
# 文件系统
###############################
class PathOverTopError(core.PError):
    """Raised when a requested file path is not inside the allowed directory."""


class FileService(core.PService):
    """
    File abstraction: opens files relative to a per-extension asset
    directory, ``<top dir>/asset/<extension_name>``.
    """

    def __init__(self, app: PApplication):
        """
        :param app: application object; its top directory and its
            ``argv["extension_name"]`` entry determine the asset directory
        """
        super().__init__(app)

        asset_root = self._app.get_top_dir() + '/asset'
        extension_dir = asset_root + f'/{self._app.argv["extension_name"]}'
        # The directory is created here if it does not exist yet.
        self.file_dir = SvrTools.verify_dir(extension_dir)

    def open(self, path: str, **kwargs):
        """Open a file located under the asset directory.

        :param path: file path relative to the asset directory
        :param kwargs: forwarded unchanged to the builtin ``open``
        :return: the file object returned by the builtin ``open``
        :raises PathOverTopError: if the resulting path is not inside the
            asset directory
        """
        target = f'{self.file_dir}/{path}'

        # Refuse anything that resolves outside the asset directory.
        if not SvrTools.is_included(self.file_dir, target):
            raise PathOverTopError(target, self.file_dir, info='Target file not in file directory')

        return open(target, **kwargs)


###############################
# 网络
###############################
class NetServe(core.PService):
    """
    Download service: a thin wrapper around the ``requests`` library.
    """

    def __init__(self, app: PApplication):
        """
        :param app: application object, passed on to the base service
        """
        super().__init__(app)

        self.requests = requests

    def get_page(self, url: str, **kwargs) -> requests.Response:
        """Send a GET request to ``url`` and return the response.

        :param url: address to fetch
        :param kwargs: forwarded unchanged to ``requests.get``. Passing
            ``timeout=<seconds>`` is recommended: without it the call may
            block indefinitely on an unresponsive server.
        :return: the response object from ``requests.get``
        """
        return self.requests.get(url, **kwargs)

    def request(self):
        """Return the underlying ``requests`` module for direct use."""
        return self.requests


class SuperFactory(core.PFactory):
    """Factory that builds the services defined in this module by name."""

    def get_product(self, product: str) -> PService:
        """Create the service registered under ``product``.

        :param product: one of ``'db_service'``, ``'net_service'`` or
            ``'file_service'``
        :return: a new service instance built with this factory's
            application, or None when ``product`` is not a known name
        """
        known_services = (
            ('db_service', DBService),
            ('net_service', NetServe),
            ('file_service', FileService),
        )

        for service_name, service_cls in known_services:
            if product == service_name:
                return service_cls(self._app)

        # Unknown names fall through without raising.
        return None
